# Capital grid: N evenly spaced points spanning 25% to 175% of the reference level k_ss.
# NOTE(review): under the Bellman equation iterated further down (log utility, output
# k^alpha, depreciation delta), the steady state implied by the parameters set here is
# (alpha * beta / (1 - beta * (1 - delta)))^(1 / (1 - alpha)) ≈ 0.109, which lies below
# the grid's lower bound. Confirm that k_ss = 1.508 is the intended reference value.
k_ss = 1.508
N = 10000

K = collect(range(start = 0.25 * k_ss, stop = 1.75 * k_ss, length = N))

# Working arrays, one entry per grid point:
#   v_prev - value function from the previous iteration (initial guess: zero everywhere)
#   v      - value function being computed in the current iteration
#   g      - policy function: the next-period capital chosen at each grid point
v_prev = zeros(N)
v = zeros(N)
g = zeros(N)

# Model parameters, as used by the iteration: output is k^alpha, beta multiplies the
# continuation value, and a fraction delta of capital depreciates each period.
alpha = 0.3
beta = 0.6
delta = 0.75

# The iteration stops once successive value functions differ by at most `tolerance`.
# `dist` starts above `tolerance` so that the loop body runs at least once.
tolerance = 1e-8
dist = 1.0

# History of every iterate. The containers are typed because only Float64 vectors are
# ever pushed; a bare `[]` would be a Vector{Any}.
value_fns = Vector{Float64}[]
policy_fns = Vector{Float64}[]

# Next, iterate the Bellman update until the largest absolute change between successive
# value functions is no greater than `tolerance`.

"""
    bellman_update!(v, g, v_prev, K, alpha, beta, delta)

Apply one Bellman step on the capital grid `K`, overwriting `v` and `g` in place.

For each state `K[i]`, next-period capital is chosen among the grid points `K[j]` that
leave strictly positive consumption `K[i]^alpha + (1 - delta) * K[i] - K[j]`, so as to
maximize `log(consumption) + beta * v_prev[j]`. The maximizing grid point is stored in
`g[i]` and the maximized value in `v[i]`; ties go to the lowest index. A state with no
feasible choice gets `g[i] = 0` and `v[i] = -Inf`.

Returns `v`.
"""
function bellman_update!(v, g, v_prev, K, alpha, beta, delta)
    for i in eachindex(v, g, K)
        # Resources on hand at state K[i]; consumption is what is left after choosing K[j]
        resources = K[i]^alpha + (1 - delta) * K[i]
        consumption = resources .- K
        feasible = findall(>(0), consumption)

        if isempty(feasible)
            # log of a non-positive number is undefined, so mark the state as infeasible
            g[i] = zero(eltype(g))
            v[i] = -Inf
            continue
        end

        # Only evaluate the log on feasible choices
        candidates = @views log.(consumption[feasible]) .+ beta .* v_prev[feasible]

        # findmax yields the maximized value together with its position, so the value
        # does not have to be recomputed by searching K for g[i]
        best_value, best = findmax(candidates)
        g[i] = K[feasible[best]]
        v[i] = best_value
    end
    return v
end

while dist > tolerance
    # `dist` and `v_prev` are rebound on every pass and therefore need the `global`
    # declaration; `v` and `g` are only mutated in place.
    global dist, v_prev

    bellman_update!(v, g, v_prev, K, alpha, beta, delta)

    dist = maximum(abs, v .- v_prev)
    v_prev = copy(v)

    # Keep a snapshot of every iterate for plotting
    push!(value_fns, copy(v))
    push!(policy_fns, copy(g))
end

# `dist > tolerance` is also false when `dist` is NaN (for example -Inf - (-Inf) at a
# state with no feasible choice), so a loop exit alone does not prove convergence.
if !(dist <= tolerance)
    @warn "Value function iteration stopped without converging" dist tolerance
end

using Plots

# Iterations to draw: the first five, two mid-run snapshots, and the last six.
# Indices outside the completed iterations are dropped and repeats removed, so a short
# run can neither index out of bounds nor draw the same curve twice.
n_iterations = length(policy_fns)
iterations_to_plot = sort(unique(filter(
    in(1:n_iterations), vcat([1, 2, 3, 4, 5, 10, 25], (n_iterations - 5):n_iterations)
)))

"""
    plot_iterations(grid, iterates, iterations; title, ylabel, filename)

Plot `iterates[i]` against `grid` for every `i` in `iterations`, one labelled curve per
iteration, with the legend placed outside the plot area on the right. The figure is
displayed, saved to `filename`, and returned.
"""
function plot_iterations(grid, iterates, iterations; title, ylabel, filename)
    fig = plot()
    for i in iterations
        plot!(fig, grid, iterates[i]; label = "Iteration $i")
    end

    # Figure-wide attributes are applied once instead of on every curve
    plot!(fig; title = title, xlabel = "State Capital (k)", ylabel = ylabel,
        legend = :outerright)
    display(fig)
    savefig(fig, filename)
    return fig
end

p = plot_iterations(K, value_fns, iterations_to_plot;
    title = "Value Function Iteration", ylabel = "Value (v)",
    filename = "pset3_value_iteration.png")

p2 = plot_iterations(K, policy_fns, iterations_to_plot;
    title = "Policy Function Iteration", ylabel = "Chosen Capital (k')",
    filename = "pset3_policy_iteration.png")
